{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"align-center\">\n",
    "<a href=\"https://oumi.ai/\"><img src=\"https://oumi.ai/docs/en/latest/_static/logo/header_logo.png\" height=\"200\"></a>\n",
    "\n",
    "[![Documentation](https://img.shields.io/badge/Documentation-latest-blue.svg)](https://oumi.ai/docs/en/latest/index.html)\n",
    "[![Discord](https://img.shields.io/discord/1286348126797430814?label=Discord)](https://discord.gg/oumi)\n",
    "[![GitHub Repo stars](https://img.shields.io/github/stars/oumi-ai/oumi)](https://github.com/oumi-ai/oumi)\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/oumi-ai/oumi/blob/main/notebooks/Oumi - Custom Judge.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "</div>\n",
    "\n",
    "👋 Welcome to Open Universal Machine Intelligence (Oumi)!\n",
    "\n",
    "🚀 Oumi is a fully open-source platform that streamlines the entire lifecycle of foundation models - from [data preparation](https://oumi.ai/docs/en/latest/resources/datasets/datasets.html) and [training](https://oumi.ai/docs/en/latest/user_guides/train/train.html) to [evaluation](https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html) and [deployment](https://oumi.ai/docs/en/latest/user_guides/launch/launch.html). Whether you're developing on a laptop, launching large scale experiments on a cluster, or deploying models in production, Oumi provides the tools and workflows you need.\n",
    "\n",
    "🤝 Make sure to join our [Discord community](https://discord.gg/oumi) to get help, share your experiences, and contribute to the project! If you are interested in joining one of the community's open-science efforts, check out our [open collaboration](https://oumi.ai/community) page.\n",
    "\n",
    "⭐ If you like Oumi and you would like to support it, please give it a star on [GitHub](https://github.com/oumi-ai/oumi)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Custom Judge\n",
    "\n",
    "Our platform offers [Oumi judge](https://github.com/oumi-ai/oumi/blob/main/notebooks/Oumi%20-%20Oumi%20Judge.ipynb), a judge that can help you filter examples that are non-helpful, dishonest, or unsafe out of your training dataset. Alternatively, our platform enables you to define a custom judge that can label and filter user-defined attributes, as well as user templates and custom parsing logic for judgements. This notebook demonstates how to build a custom judge."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 📋 Prerequisites\n",
    "\n",
    "## Oumi Installation\n",
    "\n",
    "First, let's install Oumi. You can find more detailed instructions [here](https://oumi.ai/docs/en/latest/get_started/installation.html).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install oumi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Problem Statement\n",
    "\n",
    "Suppose that, for the sake of this tutorial, you live in a country where bananas are strictly illegal and talking about them is a crime. You want to finetune a LLM with a training dataset, but you want to ensure that the resulting model will never mention bananas in its response. You decide to build a custom judge that labels all the examples that mention bananas (either directly or indirectly), so that you can filter them our of your data. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attribute Definition\n",
    "\n",
    "In order to define a custom judge, we first need to define the `attribute` it will judge. This requires the following definitions:\n",
    "1. Message: Structure of the `Message` that will be judged.\n",
    "2. System Instruction: Instruction to request that the model operates as a judge.\n",
    "3. Judgement: Structure of the `Message` that includes the judgement and how it can be parsed to extract the label.\n",
    "4. (Optional) Few-shot examples. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 1. Message (to be judged)\n",
    "\n",
    "Let's name the message to be judged `BananaMessage` and ensure it inherits from `TemplatedMessage`. This is a message intended to train a model, so it must consist of a request (by a user) and a reponse (by the model, i.e., the AI assistant). We could alternatively only include the AI assistant's response, but combining it with the request will provide more context to the judge. \n",
    "- The required `role` parameter must be set to `USER`. This is the default setting for the messages to be judged.\n",
    "- The parameters `request` and `response` are user-defined and optional. We can have as many user-defined parameters as we want and name them any way we want, in order to capture all information that must be embedded in each message. \n",
    "- The required `template` parameter (a [Jinja2 Template](https://pypi.org/project/Jinja2/)) describes how the user-defined parameters (`request` and `response` in our specific example) will be combined to produce the message to be judged. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from oumi.core.types.conversation import Role, TemplatedMessage\n",
    "\n",
    "\n",
    "class BananaMessage(TemplatedMessage):\n",
    "    role: Role = Role.USER\n",
    "    request: str\n",
    "    response: str\n",
    "    template: str = \"<request>{{request}}</request><response>{{response}}</response>\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can generate a `BananaMessage` and inspect it as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<request>What is your favorite fruit?</request><response>Banana!</response>\n"
     ]
    }
   ],
   "source": [
    "my_message = BananaMessage(request=\"What is your favorite fruit?\", response=\"Banana!\")\n",
    "print(my_message.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 2. System Instruction\n",
    "\n",
    "Next, we must define a message, to be used as a System Instruction, which will instruct the model to operate as a judge. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "judge_SI = (\n",
    "    \"Act as a judge and try to identify the illegal word 'banana' in the text. \"\n",
    "    \"Specifically, if bananas are mentioned in the response, either directly \"\n",
    "    \"or indirectly (for example: by describing how they look rather than \"\n",
    "    \"explicitly mentioning the word 'bananas'), your answer should be 'True'. \"\n",
    "    \"If there is absolutely no mention of bananas, your answer should be 'False'. \"\n",
    "    \"Only respond with 'True' or 'False', no other information is nessesary.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 3. Judgement\n",
    "\n",
    "Next, we must define how the answer from the judge should be structured and how to parse it to extract the label. For this, we define the `BananaJudgement` class, which must inherit from `BaseJudgeOutput`. This class is also a `TemplatedMessage` and already includes the variable `raw_judgement`; this is where the judge's reponse will be written to. Similarly to `BananaMessage`, we can define a template in XML format (or whatever format makes sense for our use case) to easily extract the `label` from the judge's response (`raw_judgement`). But, for simplicity, we set no template (no special characters or words to wrap the `raw_judgement`), which means that we teach the judge to respond directly with the label. This is consistent with our `judge_SI` (above), where we request that the model only respond with `True` or `False`.\n",
    "\n",
    "Note: If, for instance, we wanted to also include an explanation, we would need to define an additional `explanation` field, combine it with `raw_judgement` in the `template`, and finally define the `_transform_model_output` for our custom judge. `_transform_model_output` describes how to extract the label (and the other custom fields such as our `explanation`) from the `raw_judgement`. See [Oumi judge](https://github.com/oumi-ai/oumi/blob/main/src/oumi/judges/oumi_judge.py) for reference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from oumi.judges.base_judge import BaseJudgeOutput\n",
    "\n",
    "\n",
    "class BananaJudgement(BaseJudgeOutput):\n",
    "    role: Role = Role.ASSISTANT\n",
    "    template: str = \"{{ raw_judgement }}\""
   ]
  },
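  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The note in the Judgement section above describes a judge that returns an explanation alongside the label. As a rough sketch of the parsing step only, the next cell extracts both fields from a raw response using plain Python and the standard `re` module. The XML-style tags, the sample response, and the helpers `extract_field` and `parse_bool_label` are hypothetical and not part of Oumi; in a real judge, similar logic would presumably live in `_transform_model_output`. Returning `None` for a missing tag or an unrecognized label keeps unparseable responses distinguishable from a genuine `False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Hypothetical raw judge response that wraps its fields in XML-style tags.\n",
    "RAW_JUDGEMENT = (\n",
    "    \"<explanation>The response describes a curved yellow fruit.</explanation>\"\n",
    "    \"<judgement>True</judgement>\"\n",
    ")\n",
    "\n",
    "\n",
    "def extract_field(raw_judgement, tag):\n",
    "    \"\"\"Returns the text inside <tag>...</tag>, or None if the tag is absent.\"\"\"\n",
    "    match = re.search(rf\"<{tag}>(.*?)</{tag}>\", raw_judgement, re.DOTALL)\n",
    "    return match.group(1).strip() if match else None\n",
    "\n",
    "\n",
    "def parse_bool_label(text):\n",
    "    \"\"\"Maps 'True'/'False' (any casing) to a bool; returns None otherwise.\"\"\"\n",
    "    if text is None:\n",
    "        return None\n",
    "    normalized = text.strip().lower()\n",
    "    if normalized == \"true\":\n",
    "        return True\n",
    "    if normalized == \"false\":\n",
    "        return False\n",
    "    return None\n",
    "\n",
    "\n",
    "explanation = extract_field(RAW_JUDGEMENT, \"explanation\")\n",
    "label = parse_bool_label(extract_field(RAW_JUDGEMENT, \"judgement\"))\n",
    "print(\"Explanation:\", explanation)\n",
    "print(\"Label:\", label)"
   ]
  },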
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 4. Few-shot examples\n",
    "\n",
    "Finally let's (optionally) define few-shot examples, which help our custom judge understand its task, as well as the format it should respond with. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "few_shot_examples = [\n",
    "    BananaMessage(\n",
    "        request=\"How does your favorite fruit look like?\",\n",
    "        response=\"It's curved and yellow with a thick skin and soft sweet flesh\",\n",
    "    ),\n",
    "    BananaJudgement(\n",
    "        raw_judgement=\"True\",\n",
    "    ),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Putting everything together\n",
    "\n",
    "We can now define an attribute (`JudgeAttribute`) to describe to our custom judge how to judge, as follows. \n",
    "\n",
    "Note that we have set the `value_type` (i.e., the type of the `label` extracted from the judge's response) as `BOOL`. Other options are `CATEGORICAL` and `LIKERT_5`. For more details see our [judge config](https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/judge_config.py)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from oumi.core.configs.judge_config import JudgeAttribute, JudgeAttributeValueType\n",
    "\n",
    "banana_attribute = JudgeAttribute(\n",
    "    name=\"banana_attribute\",\n",
    "    system_prompt=judge_SI,\n",
    "    examples=few_shot_examples,\n",
    "    value_type=JudgeAttributeValueType.BOOL,\n",
    ")"
   ]
  },
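  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build intuition for how these pieces fit together, the next cell lays out a system instruction, few-shot pairs, and a message to be judged as a chat-style prompt. This is an illustrative sketch in plain Python: `build_judge_prompt` is a hypothetical helper, and the way Oumi assembles prompts internally may differ. The role order follows the roles chosen above (`USER` for the message to be judged, `ASSISTANT` for the judgement)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_judge_prompt(system_instruction, few_shot_pairs, message_to_judge):\n",
    "    \"\"\"Returns a list of {'role', 'content'} dicts in chat order (hypothetical).\"\"\"\n",
    "    prompt = [{\"role\": \"system\", \"content\": system_instruction}]\n",
    "    # Each few-shot pair is (message to be judged, expected judgement).\n",
    "    for judged_text, judgement_text in few_shot_pairs:\n",
    "        prompt.append({\"role\": \"user\", \"content\": judged_text})\n",
    "        prompt.append({\"role\": \"assistant\", \"content\": judgement_text})\n",
    "    # The message we actually want judged always comes last.\n",
    "    prompt.append({\"role\": \"user\", \"content\": message_to_judge})\n",
    "    return prompt\n",
    "\n",
    "\n",
    "demo_prompt = build_judge_prompt(\n",
    "    system_instruction=\"Only respond with 'True' or 'False'.\",\n",
    "    few_shot_pairs=[\n",
    "        (\n",
    "            \"<request>What does your favorite fruit look like?</request>\"\n",
    "            \"<response>It's curved and yellow.</response>\",\n",
    "            \"True\",\n",
    "        )\n",
    "    ],\n",
    "    message_to_judge=(\n",
    "        \"<request>Do you like hiking?</request>\"\n",
    "        \"<response>I love all sorts of sports.</response>\"\n",
    "    ),\n",
    ")\n",
    "demo_roles = [turn[\"role\"] for turn in demo_prompt]\n",
    "print(demo_roles)"
   ]
  },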
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Judge Config\n",
    "\n",
    "After defining the attribute(s) that our judge will label (`banana_attribute`), we also need to define the underlying model that we will use for inference. Specifically, we need to provide the `ModelParams` and `GenerationConfig`. For more details on these, please refer to our [Oumi judge](https://github.com/oumi-ai/oumi/blob/main/notebooks/Oumi%20-%20Oumi%20Judge.ipynb) notebook. Once we define these parameters, we have a `JudgeConfig` that fully describes the configuration of our judge."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from oumi.core.configs import GenerationParams, JudgeConfig, ModelParams\n",
    "\n",
    "banana_config = JudgeConfig(\n",
    "    attributes={\"banana_attribute\": banana_attribute},\n",
    "    model=ModelParams(model_name=\"HuggingFaceTB/SmolLM2-135M-Instruct\"),\n",
    "    generation=GenerationParams(max_new_tokens=1024),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Judge Definition\n",
    "\n",
    "The final step is to define our custom judge (`BananaJudge`) which should inherit from `BaseJudge`. There are 3 functions that we can optionally define, which describe how the judge will convert its input data and/or the judgment.\n",
    "\n",
    "- `_transform_conversation_input`: Defines how to convert our input (`BananaMessage`) if this is provided as `oumi.core.types.turn.Conversation`. This function is not nessesary in our current implementation, since our input is a `TemplatedMessage`.\n",
    "- `_transform_dict_input`: Defines how to convert our input (`BananaMessage`) if this is provided as `dict`. Again, this function is not nessesary in our current implementation, since our input is a `TemplatedMessage`.\n",
    "- `_transform_model_output`: Defines how to parse `raw_judgement` to multiple user-defined fields (which can be useful to the user to better undertand the label) and how to extract the `label`. Since in our implementation we train the judge to directly respond with the `label`, this function is not really needed. All we need to do is \"pass through\" the `model_output` to `raw_judgement`. \n",
    "\n",
    "Once `BananaJudge` is defined, we can instantiate it, using our `JudgeConfig` (`banana_config`). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2025-01-16 15:48:03,229][oumi][rank0][pid:79301][MainThread][INFO]][models.py:174] Building model using device_map: auto (DeviceRankInfo(world_size=1, rank=0, local_world_size=1, local_rank=0))...\n",
      "[2025-01-16 15:48:03,335][oumi][rank0][pid:79301][MainThread][INFO]][models.py:244] Using model class: <class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'> to instantiate model.\n"
     ]
    }
   ],
   "source": [
    "from oumi.judges.base_judge import BaseJudge\n",
    "\n",
    "\n",
    "class BananaJudge(BaseJudge):\n",
    "    def _transform_conversation_input(self, conversation):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def _transform_dict_input(self, raw_input):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def _transform_model_output(self, model_output):\n",
    "        return BananaJudgement(raw_judgement=model_output)\n",
    "\n",
    "\n",
    "my_banana_judge = BananaJudge(config=banana_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Judge Inference\n",
    "\n",
    "We can now call our `BananaJudge`'s `judge` method to judge a list of messages (i.e., our training dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating Model Responses: 100%|██████████| 3/3 [00:02<00:00,  1.25it/s]\n"
     ]
    }
   ],
   "source": [
    "training_dataset = [\n",
    "    BananaMessage(\n",
    "        request=\"Do you like apples?\",\n",
    "        response=\"Not as much as I like bananas and kiwis.\",\n",
    "    ),\n",
    "    BananaMessage(\n",
    "        request=\"What did you eat earlier?\",\n",
    "        response=\"I ate a yellow tropical fruit, similar to plantain\",\n",
    "    ),\n",
    "    BananaMessage(\n",
    "        request=\"Do you like hiking?\",\n",
    "        response=\"I love all sorts of sports.\",\n",
    "    ),\n",
    "]\n",
    "\n",
    "judge_output = my_banana_judge.judge(training_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inspecting the judgments\n",
    "\n",
    "Let's inspect the judge's output. The first two conversations mention (explicitly and implicitly respectively) bananas, thus they should be filtered. Our judge indicates this by setting the `label` to `True`. The third conversation is related to sports; thus our judge sets the `label` to `False`, indicating that it is safe to train with it. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: <request>Do you like apples?</request><response>Not as much as I like bananas and kiwis.</response>\n",
      "Judgement: True\n",
      "Input: <request>What did you eat earlier?</request><response>I ate a yellow tropical fruit, similar to plantain</response>\n",
      "Judgement: True\n",
      "Input: <request>Do you like hiking?</request><response>I love all sorts of sports.</response>\n",
      "Judgement: False\n"
     ]
    }
   ],
   "source": [
    "def inspect_judge_output(judge_output, training_dataset):\n",
    "    \"\"\"Prints the judge output in a human-readable format.\"\"\"\n",
    "    for conversation, judgement in zip(training_dataset, judge_output):\n",
    "        print(\"Input:\", conversation.content)\n",
    "        print(\"Judgement:\", judgement[\"banana_attribute\"][\"label\"])\n",
    "\n",
    "\n",
    "inspect_judge_output(judge_output, training_dataset)"
   ]
  },
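  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the labels in hand, filtering is a matter of keeping only the examples that were not flagged. The next cell is a minimal sketch on stand-in data, assuming one judgement per example and the `judgement[\"banana_attribute\"][\"label\"]` access pattern used above; `filter_flagged`, `demo_dataset`, and `demo_judgements` are hypothetical names. It keeps an example only when its label is exactly `False`, so examples whose label could not be determined are dropped as well, which is the conservative choice for this use case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-in data mirroring the shapes used above: one judgement per example.\n",
    "demo_dataset = [\n",
    "    \"Not as much as I like bananas and kiwis.\",\n",
    "    \"I ate a yellow tropical fruit, similar to plantain\",\n",
    "    \"I love all sorts of sports.\",\n",
    "]\n",
    "demo_judgements = [\n",
    "    {\"banana_attribute\": {\"label\": True}},\n",
    "    {\"banana_attribute\": {\"label\": True}},\n",
    "    {\"banana_attribute\": {\"label\": False}},\n",
    "]\n",
    "\n",
    "\n",
    "def filter_flagged(dataset, judgements, attribute_name):\n",
    "    \"\"\"Keeps only the examples whose label for `attribute_name` is False.\"\"\"\n",
    "    if len(dataset) != len(judgements):\n",
    "        raise ValueError(\"Expected exactly one judgement per example.\")\n",
    "    return [\n",
    "        example\n",
    "        for example, judgement in zip(dataset, judgements)\n",
    "        if judgement[attribute_name][\"label\"] is False\n",
    "    ]\n",
    "\n",
    "\n",
    "clean_dataset = filter_flagged(demo_dataset, demo_judgements, \"banana_attribute\")\n",
    "print(clean_dataset)"
   ]
  },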
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Conclusion\n",
    "\n",
    "Thanks to your judge's hard work, your trained model's conversations will be banana-free! Remember that the judge is as powerful as the underlying model (Qwen2 0.5B in this tutorial), so if the quality of judgments seems unsatisfactory, consider experimenting with larger models. "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "oumi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
